#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import numpy as np
import matplotlib.pyplot as plt
# Apply the plot style sheet './MNRAS_Style.mplstyle'. The path is relative to
# the current working directory, so the script has to be launched from the
# folder that contains the style file (the data paths below are cwd-relative too).
plt.style.use('./MNRAS_Style.mplstyle')

# MJDs at which plot_w_CD draws green downward arrows on the gamma-ray panel.
# NOTE(review): presumably neutrino arrival times (hence the name) — confirm
# the event(s) and the precision of these values.
NEUTRINOS = [58018.5]

class Gamma:
    """A single bin of a gamma-ray light curve.

    getGammaData fills ``flux``/``flux_err`` for bins it treats as
    detections and ``upp`` (an upper limit) for the remaining bins; fields
    that are not set stay 0. plot_w_CD relies on this: ``upp == 0`` marks a
    detection.

    All fields are optional keyword arguments defaulting to 0, so
    ``Gamma()`` followed by attribute assignment keeps working.
    """

    def __init__(self, jd=0, flux=0, flux_err=0, upp=0):
        # Bin time. NOTE(review): named "jd" but compared against MJD-sized
        # constants elsewhere (e.g. 59976.5) — presumably MJD; confirm.
        self.jd = jd
        self.flux = flux          # measured flux (detections only, else 0)
        self.flux_err = flux_err  # uncertainty on ``flux`` (else 0)
        self.upp = upp            # upper limit (non-detections only, else 0)

    def __repr__(self):
        return "{}(jd={!r}, flux={!r}, flux_err={!r}, upp={!r})".format(
            type(self).__name__, self.jd, self.flux, self.flux_err, self.upp)

def getGammaData(obj, data_dir='../../data/2d/', max_mjd=59976.5,
                 ts_threshold=4, max_flux=1e-5):
    """Read the light curve ``<data_dir>/<obj>_172800_lc.csv`` into Gamma objects.

    The file is parsed as comma-separated text. Lines starting with '#' and
    blank lines are skipped. Columns used (0-based): 0 time, 2 flux,
    3 flux error, 4 upper limit, 5 TS.

    A row is dropped when its flux column is >= ``max_flux`` or its time is
    > ``max_mjd``. Rows whose TS is > ``ts_threshold`` become detections
    (``flux``/``flux_err`` set). All other kept rows get only ``upp`` set
    from column 4.

    Parameters
    ----------
    obj : str
        Source name used to build the file name.
    data_dir : str
        Directory holding the light-curve files (relative to the cwd by default).
    max_mjd : float
        Rows later than this time are ignored.
    ts_threshold : float
        TS value above which a row counts as a detection.
    max_flux : float
        Rows whose flux column is at or above this value are ignored.
        NOTE(review): presumably a guard against failed/unphysical fits — confirm.

    Returns
    -------
    list of Gamma
        One entry per kept row, in file order.
    """
    gamma = []
    # NOTE(review): "172800" equals 2 days in seconds, matching the "2d"
    # directory — presumably the light-curve bin width; confirm.
    path = os.path.join(data_dir, obj + '_172800_lc.csv')
    # The context manager closes the file even if a malformed row raises.
    with open(path, encoding='utf-8') as fop:
        for line in fop:
            # Blank lines (e.g. a trailing empty line) would otherwise fail
            # when indexing the split fields.
            if not line.strip() or line.startswith("#"):
                continue
            sl = line.split(',')
            if float(sl[2]) >= max_flux:
                continue
            jd = float(sl[0])
            if jd > max_mjd:
                continue
            obs = Gamma()
            obs.jd = jd
            TS = float(sl[5])
            if TS > ts_threshold:
                obs.flux = float(sl[2])
                obs.flux_err = float(sl[3])
            else:
                obs.upp = float(sl[4])
            gamma.append(obs)
    return gamma

def readCD(obj):
    """Read CD measurements from ``<obj>_CD.dat`` (path relative to the cwd).

    Each non-blank line that does not start with '#' is split on commas.
    Its first four fields are interpreted as MJD, CD, CD error and the
    telescope/band label. Surrounding whitespace is stripped from the label,
    so it matches the keys of plot_w_CD's colour table exactly.

    Parameters
    ----------
    obj : str
        Source name (or path prefix) used to build the file name.

    Returns
    -------
    list of tuple
        ``(MJD, CD, CD_err, tel_band)`` per data line, in file order, with the
        first three as floats and ``tel_band`` as a str.
    """
    CDs = []
    with open(obj + '_CD.dat', encoding='utf-8') as fop:
        for line in fop:
            # Skip blank and comment lines instead of failing on float('').
            if not line.strip() or line.startswith('#'):
                continue
            sl = line.split(',')
            MJD = float(sl[0])
            CD = float(sl[1])
            CD_err = float(sl[2])
            # strip() (not just rstrip('\n')) so " ZTF g " still maps to "ZTF g".
            tel_band = sl[3].strip()
            CDs.append((MJD, CD, CD_err, tel_band))
    return CDs

def _draw_window_mean(ax, values, x_start, x_end):
    """Mark the mean of ``values`` and its standard error over [x_start, x_end].

    Draws a filled band from mean - std/sqrt(N) to mean + std/sqrt(N)
    (zorder 1) and a horizontal line at the mean (zorder 2). ``np.std`` is
    used with its default ddof=0. Nothing is drawn for an empty window: the
    mean is undefined there, and NumPy would only emit RuntimeWarnings and
    NaN coordinates.
    """
    if not values:
        return
    mean = np.mean(values)
    sem = np.std(values) / np.sqrt(len(values))
    ax.fill_between([x_start, x_end], [mean - sem, mean - sem],
                    [mean + sem, mean + sem], color='#F7F056', zorder=1)
    ax.hlines(mean, x_start, x_end, color='#F1932D', zorder=2)


def plot_w_CD(obj, CDs, yscale='linear'):
    """Plot the gamma-ray light curve of ``obj`` above its CD measurements.

    The figure is saved to ``<obj>_all.pdf`` (bbox 'tight') and then closed.

    Top panel:
      - detections (flux and error, multiplied by 1e6 to match the axis label)
        drawn as black points;
      - non-detections drawn as upper-limit carets;
      - the two pattern windows, a shaded interval, and arrows at the MJDs
        in NEUTRINOS;
      - a twin top axis with calendar-year ticks.

    Bottom panel (log y-axis, limits 0.5-25):
      - CD values coloured by telescope/band;
      - the mean and standard error of CD in three equal-length windows
        before, during and after the second pattern.

    Parameters
    ----------
    obj : str
        Source name, passed to getGammaData and used for the output file name.
    CDs : iterable of (MJD, CD, CD_err, tel_band) tuples
        As returned by readCD. ``tel_band`` must be a key of the colour
        table below, otherwise KeyError is raised.
    yscale : str
        Matplotlib y-scale of the gamma-ray panel ('linear' or 'log').
    """
    # Pattern windows (MJD) marked on both panels.
    p1_start, p1_end = 55510, 55700
    p2_start, p2_end = 57872.2, 58083.3
    window_color = '#4477AA'

    data = getGammaData(obj)
    # A zero upper limit marks a detection; anything else is a non-detection.
    fluxes = [o for o in data if o.upp == 0]
    uppers = [o for o in data if o.upp != 0]

    fig, ax = plt.subplots(2, 1, sharex=True)
    fig.set_size_inches(17, 10)

    # NOTE: this changes the global rcParams (it persists after the call) and
    # needs a working LaTeX installation.
    plt.rc('text', usetex=True)
    ax[0].set_ylabel(r'Flux [${\rm10^{-6} ph\,\, sec^{-1}\, cm^{-2}}$]')

    max_fl = 0
    min_jd = 9e9
    max_jd = 0
    for o in fluxes:
        ax[0].errorbar([o.jd], [1e6 * o.flux], yerr=[1e6 * o.flux_err], markersize=0.5, elinewidth=0.3, color="k")
        ax[0].plot([o.jd], [1e6 * o.flux], 'ko', markersize=1)
        max_fl = max(max_fl, 1e6 * o.flux)
        min_jd = min(min_jd, o.jd)
        max_jd = max(max_jd, o.jd)

    for u in uppers:
        # Recent Matplotlib raises ValueError for negative yerr. With
        # uplims=True a tiny positive stub still puts the caret at the limit.
        ax[0].errorbar([u.jd], [1e6 * u.upp], yerr=1e-6, uplims=True, capsize=1.8, markeredgewidth=0, color='#AA3377')
        min_jd = min(min_jd, u.jd)
        max_jd = max(max_jd, u.jd)

    # 10% headroom above the brightest detection (max_fl starts at 0, so it
    # is never negative).
    max_fl = max_fl * 1.1

    # Set the scale on the gamma-ray panel explicitly. plt.yscale() acts on
    # the current axes, which after plt.subplots() is the bottom panel.
    ax[0].set_yscale(yscale)
    if yscale == 'log':
        ax[0].set_ylim([1e-3, max_fl])
    else:
        ax[0].set_ylim([0, max_fl])
    ax[0].set_xlim([min_jd, max_jd])

    # Pattern windows with their labels; the arrow after "2" points right.
    for x in (p1_start, p1_end):
        ax[0].vlines(x, 0, 0.8, color=window_color)
    ax[0].text(55567, 0.6, '1', fontsize=30, color=window_color)
    for x in (p2_start, p2_end):
        ax[0].vlines(x, 0, 0.8, color=window_color)
    ax[0].text(57750, 0.6, '2', fontsize=30, color=window_color)
    ax[0].arrow(57840, 0.625, 75, 0, head_width=0.015, head_length=11, width=0.005, facecolor=window_color, edgecolor=window_color)

    # fill area of excess from https://www.science.org/doi/pdf/10.1126/science.aat2890
    ax[0].axvspan(56937.81, 57096.21, color='#ffb6c1')

    for JDneut in NEUTRINOS:
        ax[0].arrow(JDneut, 0.8, 0, -0.2, width=5, head_width=23, head_length=0.013, facecolor='#228833', edgecolor='#228833', zorder=3)

    # Secondary top axis with year ticks placed at the MJD of 1 January.
    ax2 = ax[0].twiny()
    ax2.set_xlim(ax[0].get_xlim())
    ax2.set_xticks([54832, 55197, 55562, 55927, 56293, 56658, 57023, 57388, 57754, 58119, 58484, 58849, 59215, 59580, 59945])
    ax2.set_xticklabels(['2009', '2010', '2011', '2012', '2013', '2014', '2015', '2016', '2017', '2018', '2019', '2020', '2021', '2022', '2023'])
    ax2.grid(None)

    colors = {'ASAS-SN g'    : "#332288",
              'ASAS-SN V'    : "#88CCEE",
              'Pan-STARRS g' : "#44AA99",
              'Pan-STARRS r' : "#117733",
              'Pan-STARRS i' : "#AA4499",
              'SMARTS R'     : "#999933",
              'ZTF g'        : "#DDCC77",
              'ZTF r'        : "#CC6677",
              'ZTF i'        : "#882255"}

    ax[1].set_ylim([0.5, 25])
    ax[1].set_yscale('log')

    # Three windows of equal length DT: before, during and after pattern 2.
    DT = p2_end - p2_start
    LEFT_EDGE_BEFORE = p2_start - DT
    RIGHT_EDGE_AFTER = p2_end + DT
    CD_before = []
    CD_during = []
    CD_after = []
    labelled = set()  # bands that already have a legend entry
    for MJD, CD_val, CD_val_err, tel_band in CDs:
        color = colors[tel_band]
        ax[1].errorbar([MJD], [CD_val], yerr=[CD_val_err], markersize=1, elinewidth=0.3, color=color, zorder=3)
        if tel_band in labelled:
            ax[1].plot([MJD], [CD_val], 'o', markersize=3, color=color, zorder=3)
        else:
            # Label only the first point of each band so the legend lists it once.
            ax[1].plot([MJD], [CD_val], 'o', markersize=3, color=color, label=tel_band, zorder=3)
            labelled.add(tel_band)

        if LEFT_EDGE_BEFORE <= MJD < p2_start:
            CD_before.append(CD_val)
        elif p2_start <= MJD < p2_end:
            CD_during.append(CD_val)
        elif p2_end <= MJD < RIGHT_EDGE_AFTER:
            CD_after.append(CD_val)

    # The y-limits are fixed above, so they can be read once for all window lines.
    y_lo, y_hi = ax[1].get_ylim()
    for x in (p1_start, p1_end, p2_start, p2_end):
        ax[1].vlines(x, y_lo, y_hi, color=window_color, zorder=4)

    _draw_window_mean(ax[1], CD_before, LEFT_EDGE_BEFORE, p2_start)
    _draw_window_mean(ax[1], CD_during, p2_start, p2_end)
    _draw_window_mean(ax[1], CD_after, p2_end, RIGHT_EDGE_AFTER)

    ax[1].set_xlabel('MJD')
    ax[1].set_ylabel('CD')

    ax[1].legend(loc=3, ncol=3)
    fig.savefig(obj + '_all.pdf', bbox_inches='tight')
    # Close exactly this figure so repeated calls do not accumulate figures.
    plt.close(fig)

if __name__ == "__main__":
    # Produce the combined gamma-ray + CD figure for the single source this
    # script is written for.
    source = 'J0509.4+0542'
    plot_w_CD(source, readCD(source), 'linear')
